from tqdm import tqdm
from data_utils import DataManager
from data_utils.datasets_mm_guide import MultiModalDatasetGuide, MultiModalMAEMaskGuide

# Sanity check: fetch every sample from the guided multimodal dataset and pass
# it through the MAE masking transform, so any sample that fails to load or to
# mask surfaces as an exception. The masked output itself is discarded.
# NOTE(review): all of this runs at module import time; consider an
# `if __name__ == "__main__":` guard if this file is ever imported.

manager = DataManager(('3d_thick', '3d_mid', '3d_thin'), total_dim=6144, seq_len=2048)
dataset = MultiModalDatasetGuide(data_manager=manager)
mask = MultiModalMAEMaskGuide(data_manager=manager, mask_ratio=0.75)

# Index explicitly via len()/[] rather than iterating the dataset object:
# presumably it is a map-style dataset that may not define __iter__ — confirm.
for idx in tqdm(range(len(dataset))):
    try:
        data = dataset[idx]
        # assumes each item is a sequence whose elements are the positional
        # arguments of the mask transform — TODO confirm against the dataset.
        data = mask(*data)
    except Exception:
        # Identify the offending sample before propagating the original error;
        # tqdm.write keeps the message from clobbering the progress bar.
        tqdm.write(f'Check failed at index {idx}')
        raise

print('Check Complete')
